# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements.  See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership.  The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License.  You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied.  See the License for the
# specific language governing permissions and limitations
# under the License.

library(dplyr, warn.conflicts = FALSE)

# Skip this file unless the "acero" feature is reported as available
# (skip_if_not_available() is a helper defined outside this file).
skip_if_not_available("acero")

# Shared fixture for the tests below; example_data is defined outside this file.
tbl <- example_data

test_that("slice_head/tail, ungrouped", {
  # Arrow data is unordered, so head/tail results are not deterministic and
  # cannot be compared with dplyr output; only the row counts are asserted.
  arrow_tbl <- arrow_table(tbl)

  # Slicing by an absolute number of rows
  head_by_n <- nrow(slice_head(arrow_tbl, n = 5))
  expect_equal(head_by_n, 5)

  tail_by_n <- nrow(slice_tail(arrow_tbl, n = 5))
  expect_equal(tail_by_n, 5)

  # Slicing by proportion: 25% of the fixture is expected to yield 2 rows
  head_by_prop <- nrow(slice_head(arrow_tbl, prop = 0.25))
  expect_equal(head_by_prop, 2)

  tail_by_prop <- nrow(slice_tail(arrow_tbl, prop = 0.25))
  expect_equal(tail_by_prop, 2)
})

test_that("slice_min/max, ungrouped", {
  arrow_tbl <- arrow_table(tbl)

  # Leaving with_ties at its default is rejected; the error is expected to
  # mention `with_ties = TRUE`.
  ties_error <- "with_ties = TRUE"
  expect_error(slice_max(arrow_tbl, int, n = 5), ties_error)
  expect_error(slice_min(arrow_tbl, int, n = 5), ties_error)

  # With with_ties = FALSE, results are compared against dplyr when slicing
  # by row count ...
  compare_dplyr_binding(
    .input %>%
      slice_max(int, n = 4, with_ties = FALSE) %>%
      collect(),
    tbl
  )
  compare_dplyr_binding(
    .input %>%
      slice_min(int, n = 4, with_ties = FALSE) %>%
      collect(),
    tbl
  )

  # ... and when slicing by proportion.
  compare_dplyr_binding(
    .input %>%
      slice_max(int, prop = 0.25, with_ties = FALSE) %>%
      collect(),
    tbl
  )
  compare_dplyr_binding(
    .input %>%
      slice_min(int, prop = 0.25, with_ties = FALSE) %>%
      collect(),
    tbl
  )
})

test_that("slice_sample, ungrouped", {
  skip_if_not(CanRunWithCapturedR())

  arrow_tbl <- arrow_table(tbl)

  # Options that are expected to be rejected with an error
  expect_error(
    slice_sample(arrow_tbl, replace = TRUE),
    "Sampling with replacement"
  )
  expect_error(
    slice_sample(arrow_tbl, weight_by = dbl),
    "weight_by"
  )

  # The remaining checks are random, so avoid spurious failures on CRAN
  skip_on_cran()

  # Sampling is random and the fixture has only 10 rows, so retry up to 50
  # times until the sample has exactly the expected size.
  for (attempt in seq_len(50)) {
    sampled_prop <- nrow(collect(slice_sample(arrow_tbl, prop = 0.2)))
    if (sampled_prop == 2) {
      break
    }
  }
  expect_equal(sampled_prop, 2)

  # slice_sample(n) should return n rows. A larger dataset would make an
  # exact match more likely; here we only assert that we never get more
  # than n rows.
  sampled_n <- nrow(collect(slice_sample(arrow_tbl, n = 2)))
  expect_lte(sampled_n, 2)

  # Repeat with a dataset, which matters for the UDF HACK
  skip_if_not_available("dataset")
  in_memory_ds <- InMemoryDataset$create(arrow_tbl)
  sampled_n <- nrow(collect(slice_sample(in_memory_ds, n = 2)))
  expect_lte(sampled_n, 2)
})

test_that("slice_* not supported with groups", {
  # Every slice_* verb is expected to fail with this message on grouped data
  unsupported <- "Slicing grouped data not supported in Arrow"

  arrow_tbl <- arrow_table(tbl)

  # Grouping applied beforehand through group_by()
  grouped <- group_by(arrow_tbl, lgl)
  expect_error(slice_head(grouped, n = 5), unsupported)
  expect_error(slice_tail(grouped, n = 5), unsupported)
  expect_error(slice_min(grouped, int, n = 5), unsupported)
  expect_error(slice_max(grouped, int, n = 5), unsupported)
  expect_error(slice_sample(grouped, n = 5), unsupported)

  # Grouping requested per call through the `by` argument
  expect_error(slice_head(arrow_tbl, n = 5, by = lgl), unsupported)
  expect_error(slice_tail(arrow_tbl, n = 5, by = lgl), unsupported)
  expect_error(slice_min(arrow_tbl, int, n = 5, by = lgl), unsupported)
  expect_error(slice_max(arrow_tbl, int, n = 5, by = lgl), unsupported)
  expect_error(slice_sample(arrow_tbl, n = 5, by = lgl), unsupported)
})

test_that("input validation", {
  tab <- arrow_table(tbl)

  # Invalid `prop` values: wrong type, below 0, above 1, length > 1, and NA.
  # `!!p` unquotes the current value into the expression that expect_error()
  # captures.
  for (p in list("a", -1, 2, c(0.01, 0.02), NA_real_)) {
    expect_error(
      slice_head(tab, prop = !!p),
      "`prop` must be a single numeric value between 0 and 1",
      fixed = TRUE
    )
  }

  # Arguments that slice_tail() does not accept must be rejected.
  # Match the message literally: as a regular expression, the dots in "..."
  # would be wildcards and accept any three characters between the backticks.
  expect_error(
    tab %>% slice_tail(n = 3, with_ties = FALSE),
    "`...` must be empty",
    fixed = TRUE
  )
})

test_that("n <-> prop conversion when nrow is not known", {
  # After a join, the row count of the query is expected to be unknown (NA)
  joined <- full_join(arrow_table(tbl), tbl, by = "int")
  expect_true(is.na(nrow(joined)))

  # Slicing by `prop` is expected to error on such a query
  expect_error(
    slice_min(joined, int, prop = 0.25, with_ties = FALSE),
    "Slicing with `prop` when"
  )

  # slice_sample() with `n` is expected to error as well; the message
  # contains parentheses, so it is matched literally
  expect_error(
    slice_sample(joined, n = 5),
    "slice_sample() with `n` when",
    fixed = TRUE
  )
})

# TODO: handle edge case where prop = 1, do nothing?
